/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: BUSL-1.1
*/

package watcher

import (
	"context"
	"crypto/x509"
	"encoding/asn1"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"sync"
	"testing"

	"github.com/edgelesssys/constellation/v2/internal/atls"
	"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
	"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
	"github.com/edgelesssys/constellation/v2/internal/config"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/file"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m,
		// https://github.com/census-instrumentation/opencensus-go/issues/1262
		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
		goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"),
	)
}

func TestNewUpdateableValidator(t *testing.T) {
	testCases := map[string]struct {
		variant  variant.Variant
		config   config.AttestationCfg
		snpCerts *stubSnpCerts
		wantErr  bool
	}{
		"azure": {
			variant: variant.AzureSEVSNP{},
			config:  config.DefaultForAzureSEVSNP(),
			snpCerts: &stubSnpCerts{
				ask: &x509.Certificate{},
				ark: &x509.Certificate{},
			},
		},
		"gcp": {
			variant: variant.GCPSEVES{},
			config:  &config.GCPSEVES{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
		},
		"qemu": {
			variant: variant.QEMUVTPM{},
			config:  &config.QEMUVTPM{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
		},
		"no file": {
			variant: variant.AzureSEVSNP{},
			wantErr: true,
		},
		"invalid provider": {
			variant: fakeOID{1, 3, 9900, 9999, 9999},
			config:  &config.GCPSEVES{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			handler := file.NewHandler(afero.NewMemMapFs())
			if tc.config != nil {
				require.NoError(handler.WriteJSON(
					filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
					tc.config,
				))
			}

			_, err := NewValidator(
				logger.NewTest(t),
				tc.variant,
				handler,
				tc.snpCerts,
			)
			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
			}
		})
	}
}

type stubSnpCerts struct {
	ask *x509.Certificate
	ark *x509.Certificate
}

func (s *stubSnpCerts) SevSnpCerts() (ask *x509.Certificate, ark *x509.Certificate) {
	return s.ask, s.ark
}

func TestUpdate(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	handler := file.NewHandler(afero.NewMemMapFs())

	// create server
	validator := &Updatable{
		log:         logger.NewTest(t),
		variant:     variant.Dummy{},
		fileHandler: handler,
	}

	// Update should fail if the file does not exist
	assert.Error(validator.Update())

	// write measurement config
	require.NoError(handler.WriteJSON(
		filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
		&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
	))

	// call update once to initialize the server's validator
	require.NoError(validator.Update())

	// create tls config and start the server
	serverConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator})
	require.NoError(err)
	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, "hello")
	}))
	server.TLS = serverConfig
	server.StartTLS()
	defer server.Close()

	// test connection to server
	clientOID := variant.Dummy{}
	resp, err := testConnection(t.Context(), require, server.URL, clientOID)
	require.NoError(err)
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(err)
	assert.EqualValues("hello", body)

	// update the server's validator
	validator.variant = variant.QEMUVTPM{}
	require.NoError(validator.Update())

	// client connection should fail now, since the server's validator expects a different OID from the client
	resp, err = testConnection(t.Context(), require, server.URL, clientOID)
	if err == nil {
		defer resp.Body.Close()
	}
	assert.Error(err)
}

func TestOIDConcurrency(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	handler := file.NewHandler(afero.NewMemMapFs())
	require.NoError(handler.WriteJSON(
		filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
		&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
	))

	// create server
	validator := &Updatable{
		log:         logger.NewTest(t),
		variant:     variant.Dummy{},
		fileHandler: handler,
	}

	// call update once to initialize the server's validator
	require.NoError(validator.Update())

	var wg sync.WaitGroup
	wg.Add(2 * 20)
	for i := 0; i < 20; i++ {
		go func() {
			defer wg.Done()
			assert.NoError(validator.Update())
		}()
		go func() {
			defer wg.Done()
			validator.OID()
		}()
	}
	wg.Wait()
}

func TestUpdateConcurrency(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	handler := file.NewHandler(afero.NewMemMapFs())
	validator := &Updatable{
		log:         logger.NewTest(t),
		fileHandler: handler,
		variant:     variant.Dummy{},
	}
	require.NoError(handler.WriteJSON(
		filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
		&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
		file.OptNone,
	))

	var wg sync.WaitGroup

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			assert.NoError(validator.Update())
		}()
	}

	wg.Wait()
}

func testConnection(ctx context.Context, require *require.Assertions, url string, oid variant.Getter) (*http.Response, error) {
	clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{oid}, nil)
	require.NoError(err)
	client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
	require.NoError(err)
	return client.Do(req)
}

type fakeIssuer struct {
	variant.Getter
}

func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
	return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}

type fakeOID asn1.ObjectIdentifier

func (o fakeOID) OID() asn1.ObjectIdentifier {
	return asn1.ObjectIdentifier(o)
}

func (o fakeOID) String() string {
	return o.OID().String()
}

func (o fakeOID) Equal(other variant.Getter) bool {
	return o.OID().Equal(other.OID())
}

type fakeDoc struct {
	UserData []byte
	Nonce    []byte
}
